{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Custom Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## In this tutorial, we provide an example of adapting USB to a custom dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/media/Zeus/haoc/miniconda/envs/test_semilearn_031/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from torchvision import transforms\n",
    "from semilearn import get_data_loader, get_net_builder, get_algorithm, get_config, Trainer\n",
    "from semilearn import split_ssl_data, BasicDataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Specify configs and define the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/bin/sh: 1: netstat: not found\n"
     ]
    }
   ],
   "source": [
    "# Build the run configuration in a single step: the keyword arguments below\n",
    "# form the settings dict that is handed to get_config.\n",
    "config = get_config(dict(\n",
    "    # algorithm and backbone (pretrained weights are fetched from pretrain_path)\n",
    "    algorithm='fixmatch',\n",
    "    net='vit_tiny_patch2_32',\n",
    "    use_pretrain=True,\n",
    "    pretrain_path='https://github.com/microsoft/Semi-supervised-learning/releases/download/v.0.0.0/vit_tiny_patch2_32_mlp_im_1k_32.pth',\n",
    "\n",
    "    # optimization configs\n",
    "    epoch=1,\n",
    "    num_train_iter=1000,\n",
    "    num_eval_iter=500,\n",
    "    num_log_iter=50,\n",
    "    optim='AdamW',\n",
    "    lr=5e-4,\n",
    "    layer_decay=0.5,\n",
    "    batch_size=16,\n",
    "    eval_batch_size=16,\n",
    "\n",
    "    # dataset configs\n",
    "    dataset='mnist',\n",
    "    num_labels=40,\n",
    "    num_classes=10,\n",
    "    img_size=32,\n",
    "    crop_ratio=0.875,\n",
    "    data_dir='./data',\n",
    "\n",
    "    # algorithm specific configs\n",
    "    hard_label=True,\n",
    "    uratio=2,  # multiplies batch_size to size the unlabeled loader further down\n",
    "    ulb_loss_ratio=1.0,\n",
    "\n",
    "    # device configs\n",
    "    gpu=0,\n",
    "    world_size=1,\n",
    "    num_workers=2,\n",
    "    distributed=False,\n",
    "))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_IncompatibleKeys(missing_keys=['head.weight', 'head.bias'], unexpected_keys=[])\n",
      "Create optimizer and scheduler\n"
     ]
    }
   ],
   "source": [
    "# Look up the network builder named in the config, then create the algorithm\n",
    "# around it (no tensorboard logger and no text logger are attached here).\n",
    "net_builder = get_net_builder(config.net, from_name=False)\n",
    "algorithm = get_algorithm(config, net_builder, tb_log=None, logger=None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Create dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# replace with your own code\n",
    "# Placeholder training data: 1000 random uint8 RGB images with random labels.\n",
    "# randint's upper bound is exclusive, so 256 is needed to cover the full 0..255 range.\n",
    "data = np.random.randint(0, 256, size=(1000, config.img_size, config.img_size, 3), dtype=np.uint8)\n",
    "target = np.random.randint(0, config.num_classes, size=1000)\n",
    "\n",
    "# split into a labeled subset (config.num_labels samples) and an unlabeled subset\n",
    "lb_data, lb_target, ulb_data, ulb_target = split_ssl_data(config, data, target, config.num_classes,\n",
    "                                                          config.num_labels, include_lb_to_ulb=config.include_lb_to_ulb)\n",
    "\n",
    "# weak augmentation: flip + padded random crop\n",
    "train_transform = transforms.Compose([transforms.RandomHorizontalFlip(),\n",
    "                                      transforms.RandomCrop(config.img_size, padding=int(config.img_size * 0.125), padding_mode='reflect'),\n",
    "                                      transforms.ToTensor(),\n",
    "                                      transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])\n",
    "\n",
    "# NOTE(review): this pipeline is currently identical to train_transform; for real data\n",
    "# a stronger augmentation (e.g. RandAugment) is presumably intended here -- confirm and replace.\n",
    "train_strong_transform = transforms.Compose([transforms.RandomHorizontalFlip(),\n",
    "                                             transforms.RandomCrop(config.img_size, padding=int(config.img_size * 0.125), padding_mode='reflect'),\n",
    "                                             transforms.ToTensor(),\n",
    "                                             transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])\n",
    "\n",
    "lb_dataset = BasicDataset(config.algorithm, lb_data, lb_target, config.num_classes, train_transform, is_ulb=False)\n",
    "# the unlabeled dataset must wrap the unlabeled split, not the labeled one\n",
    "ulb_dataset = BasicDataset(config.algorithm, ulb_data, ulb_target, config.num_classes, train_transform, is_ulb=True, strong_transform=train_strong_transform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# replace with your own code\n",
    "# Placeholder evaluation set: 100 random uint8 RGB images with random labels.\n",
    "# randint's upper bound is exclusive, so 256 is needed to cover the full 0..255 range.\n",
    "eval_data = np.random.randint(0, 256, size=(100, config.img_size, config.img_size, 3), dtype=np.uint8)\n",
    "eval_target = np.random.randint(0, config.num_classes, size=100)\n",
    "\n",
    "# no random augmentation at evaluation time: resize, tensor conversion, normalization only\n",
    "eval_transform = transforms.Compose([transforms.Resize(config.img_size),\n",
    "                                     transforms.ToTensor(),\n",
    "                                     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])\n",
    "\n",
    "# evaluate on the held-out eval arrays rather than on the labeled training split\n",
    "eval_dataset = BasicDataset(config.algorithm, eval_data, eval_target, config.num_classes, eval_transform, is_ulb=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# define data loaders\n",
    "# the unlabeled loader uses a batch `uratio` times the labeled batch size\n",
    "ulb_batch_size = int(config.batch_size * config.uratio)\n",
    "\n",
    "train_lb_loader = get_data_loader(config, lb_dataset, config.batch_size)\n",
    "train_ulb_loader = get_data_loader(config, ulb_dataset, ulb_batch_size)\n",
    "# NOTE(review): the eval loader relies on get_data_loader's default sampler settings;\n",
    "# confirm those defaults are appropriate for evaluation.\n",
    "eval_loader = get_data_loader(config, eval_dataset, config.eval_batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Training and evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# training and evaluation\n",
    "# wrap the config and algorithm in a Trainer, run the training loop, then\n",
    "# evaluate on the eval loader; the last expression is what the cell displays\n",
    "trainer = Trainer(config, algorithm)\n",
    "trainer.fit(train_lb_loader, train_ulb_loader, eval_loader)\n",
    "trainer.evaluate(eval_loader)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "test",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  },
  "vscode": {
   "interpreter": {
    "hash": "efd87a861e5021e4a438e5b61d692cea261dd91508182bfdfdb13fb969975ffe"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
